#from datetimei import date, datetime
import matplotlib.pyplot as plt
#from torchdiffeq import odeint
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchsde
import torch.nn.functional as F
import torch.optim as optim
# from loguru import logger
# from scipy.integrate import odeint
from torch.autograd import grad
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split
from sklearn.model_selection import train_test_split
import argparse
import csv
import random
import time

class NeuralSDE(nn.Module):
    """SDE whose drift and diffusion terms are both neural networks.

    The class exposes ``f`` (drift), ``g`` (diffusion) and the class
    attributes ``noise_type`` / ``sde_type``; instances are passed to
    ``torchsde.sdeint`` elsewhere in this file.

    Args:
        input_dim: Size of the state vector, forwarded to both sub-networks.
    """

    sde_type = "ito"
    noise_type = "diagonal"

    def __init__(self, input_dim):
        super().__init__()
        # Drift is built before diffusion; keeping this order keeps the
        # random weight initialisation identical for a given seed.
        self.drift = ComplexDrift(input_dim)
        self.diffusion = ComplexDiffusion(input_dim)

    def f(self, t, S):
        """Return the drift network's output for time ``t`` and state ``S``."""
        drift_value = self.drift(t, S)
        return drift_value

    def g(self, t, S):
        """Return the diffusion network's output for time ``t`` and state ``S``."""
        diffusion_value = self.diffusion(t, S)
        return diffusion_value

class Drift(nn.Module):
    """Linear drift ``mu * S`` with a learnable coefficient ``mu``.

    Args:
        input_dim: Accepted for signature compatibility; not used.
        mu: Initial coefficient(s). May be a tensor, NumPy array, sequence
            or scalar; it is copied into a float32 parameter, so the
            caller's object is never modified by training.
    """

    def __init__(self, input_dim, mu):
        super().__init__()
        # torch.tensor(<Tensor>) emits a copy-construct UserWarning, so go
        # through as_tensor and take an explicit detached copy instead.
        initial_mu = torch.as_tensor(mu, dtype=torch.float32).detach().clone()
        self.mu = nn.Parameter(initial_mu)

    def forward(self, t, S):
        """Return ``mu * S``; ``t`` is accepted but not used."""
        return self.mu * S

class Diffusion(nn.Module):
    """Linear diffusion ``sigma * S`` with a learnable coefficient ``sigma``.

    Args:
        input_dim: Accepted for signature compatibility; not used.
        sigma: Initial coefficient(s). May be a tensor, NumPy array, sequence
            or scalar; it is copied into a float32 parameter, so the
            caller's object is never modified by training.
    """

    def __init__(self, input_dim, sigma):
        super().__init__()
        # torch.tensor(<Tensor>) emits a copy-construct UserWarning, so go
        # through as_tensor and take an explicit detached copy instead.
        initial_sigma = torch.as_tensor(sigma, dtype=torch.float32).detach().clone()
        self.sigma = nn.Parameter(initial_sigma)

    def forward(self, t, S):
        """Return ``sigma * S``; ``t`` is accepted but not used."""
        return self.sigma * S

class ComplexDiffusion(nn.Module):
    """Diffusion term modelled by a one-hidden-layer ReLU network.

    The network maps ``input_dim -> 128 -> input_dim``.

    Args:
        input_dim: Size of the state vector (network input and output width).
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, input_dim),
        ]
        # Stored under ``network`` so state-dict keys stay ``network.<idx>.*``.
        self.network = nn.Sequential(*layers)

    def forward(self, t, S):
        """Return the network output for state ``S``; ``t`` is not used."""
        diffusion_value = self.network(S)
        return diffusion_value
    
class ComplexDrift(nn.Module):
    """Drift term modelled by a one-hidden-layer ReLU network.

    The network maps ``input_dim -> 128 -> input_dim``.

    Args:
        input_dim: Size of the state vector (network input and output width).
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, input_dim),
        ]
        # Stored under ``network`` so state-dict keys stay ``network.<idx>.*``.
        self.network = nn.Sequential(*layers)

    def forward(self, t, S):
        """Return the network output for state ``S``; ``t`` is not used."""
        drift_value = self.network(S)
        return drift_value
    
class GBM(nn.Module):
    """Geometric-Brownian-motion coefficients: drift ``mu * S``, diffusion ``sigma * S``.

    Args:
        input_dim: Accepted for signature compatibility; not used.
        mu: Initial drift coefficient(s).
        sigma: Initial diffusion coefficient(s).

    ``mu`` and ``sigma`` may be tensors, NumPy arrays, sequences or scalars.
    Each is copied into a float32 parameter, so the caller's objects are not
    shared with the module.
    """

    def __init__(self, input_dim, mu, sigma):
        super().__init__()
        # torch.tensor(<Tensor>) emits a copy-construct UserWarning (the
        # script in this file passes tensors), so copy explicitly instead.
        self.mu = nn.Parameter(
            torch.as_tensor(mu, dtype=torch.float32).detach().clone()
        )
        self.sigma = nn.Parameter(
            torch.as_tensor(sigma, dtype=torch.float32).detach().clone()
        )

    def f(self, t, S):
        """Drift function: return ``mu * S``; ``t`` is not used."""
        return self.mu * S

    def g(self, t, S):
        """Diffusion function: return ``sigma * S``; ``t`` is not used."""
        return self.sigma * S

# SDE wrapper around GBM that carries the noise_type and sde_type attributes.
class SDE(nn.Module):
    """Itô SDE with diagonal noise whose coefficients come from a ``GBM`` module.

    Args:
        input_dim: Forwarded to ``GBM``.
        mu: Drift coefficient(s), forwarded to ``GBM``.
        sigma: Diffusion coefficient(s), forwarded to ``GBM``.
    """

    sde_type = "ito"  # Itô (as opposed to Stratonovich) calculus
    noise_type = "diagonal"  # type of noise

    def __init__(self, input_dim, mu, sigma):
        super().__init__()
        # The attribute is called ``drift`` but the GBM module provides both
        # the drift (f) and the diffusion (g).
        self.drift = GBM(input_dim, mu, sigma)

    def f(self, t, S):
        """Drift: delegate to ``GBM.f``."""
        coefficients = self.drift
        return coefficients.f(t, S)

    def g(self, t, S):
        """Diffusion: delegate to ``GBM.g``."""
        coefficients = self.drift
        return coefficients.g(t, S)
    
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
        
if __name__ == '__main__':
    # Script entry point. Pipeline:
    #   1. parse hyper-parameters and seed the random number generators;
    #   2. simulate one price path per initial condition with a fixed-parameter SDE;
    #   3. fit a NeuralSDE to those paths, early-stopping once the training
    #      horizon has been extended to the full time grid (k == 1);
    #   4. evaluate the best checkpoint and append one summary row to a CSV file.
    import os  # local import: only needed to create the checkpoint directory

    def _str2bool(value):
        """Parse a command-line boolean (``type=bool`` maps any non-empty string to True)."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')

    parser = argparse.ArgumentParser(description='PI_ROPF')
    ### NN HYPERPARAMS
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--nHiddenUnit', type=int, default=50, help='number of hidden units')
    parser.add_argument('--activation', type=str, default="RELU", help='activation_function')
    parser.add_argument('--optimizer', type=int, default=2,
                        help='GD algorithm: 1=Adam, 2=Adadelta, 3=SGD, anything else=LBFGS')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--batchsize', type=int, default=50, help='training batch size')
    parser.add_argument('--train_test_and_valid_split', type=float, default=.2)
    parser.add_argument('--normalize', type=_str2bool, default=False)
    parser.add_argument('--max_epochs', type=int, default=700, help='max training epochs')
    parser.add_argument('--k', type=float, default=.1)
    parser.add_argument('--k_check', type=float, default=.5)
    parser.add_argument('--lambdA_train', type=float, default=.5)

    parser.add_argument('--delta_k', type=float, default=.025)
    parser.add_argument('--delta_k_epoch', type=int, default=100)
    parser.add_argument('--delta_k_split', type=float, default=1)

    parser.add_argument('--mu_lb', type=float, default=.1)
    parser.add_argument('--mu_ub', type=float, default=1)
    parser.add_argument('--sigma_lb', type=float, default=.01)
    parser.add_argument('--sigma_ub', type=float, default=.5)

    parser.add_argument('--max_patience', type=int, default=5)
    parser.add_argument('--nSamples', type=int, default=10000, help='number of simulated price paths')
    parser.add_argument('--nLayer', type=int, default=5, help='number of layers')
    parser.add_argument('--id', type=int, default=5,
                        help='run identifier used in the checkpoint file name and the results row')
    parser.add_argument('--plot_paths', action='store_true',
                        help='show a (blocking) plot of one asset for every simulated path')

    args = vars(parser.parse_args())  # change to dictionary

    set_seed(args['seed'])
    batch_size = args['batchsize']
    k = args['k']
    delta_k_epoch = args['delta_k_epoch']
    # Number of assets
    num_assets = 50

    # Drift (mu) and volatility (sigma) of the data-generating SDE, one value per asset.
    mu = np.random.uniform(args['mu_lb'], args['mu_ub'], num_assets)
    sigma = np.random.uniform(args['sigma_lb'], args['sigma_ub'], num_assets)

    # Index boundaries between the three initial-condition files.
    num_paths = args['nSamples']
    train_end = int(num_paths*.8)
    valid_end = int(num_paths*.9)
    holdout_size = int(num_paths*.1)

    # The counts are computed first: reusing the enclosing quote character
    # inside an f-string replacement field is a SyntaxError before Python 3.12.
    asset_price_training = np.load(f'portfolio_data/asset_prices_training_{train_end}.npy').T
    asset_price_validation = np.load(f'portfolio_data/asset_prices_validation_{holdout_size}.npy').T
    asset_price_test = np.load(f'portfolio_data/asset_prices_test_{holdout_size}.npy').T

    T = 1.0
    dt = 1/252
    # round() guards against T/dt landing just below an integer in floating point.
    time_points = torch.linspace(0, T, int(round(T/dt)) + 1)

    # Convert parameters to tensors
    mu_torch = torch.tensor(mu, dtype=torch.float32)
    sigma_torch = torch.tensor(sigma, dtype=torch.float32)

    # Data-generating model. Its parameters are fixed, so it is built once
    # rather than once per path.
    sde_model = SDE(num_assets, mu_torch, sigma_torch)

    # Simulate asset dynamics: each initial price vector S(0) is evolved under
    # the same fixed-parameter SDE, so paths differ through S(0) and the noise.
    # The neural SDE is then trained to reproduce these trajectories.
    price_paths = []
    initial_states = []

    with torch.no_grad():  # data generation only; no autograd graph is needed
        for index in range(num_paths):
            if index < train_end:
                S0 = asset_price_training[index, :num_assets]
            elif index < valid_end:
                S0 = asset_price_validation[index - train_end, :num_assets]
            else:
                S0 = asset_price_test[index - valid_end, :num_assets]
            S0_torch = torch.tensor(S0, dtype=torch.float32).unsqueeze(0)  # batch of one
            initial_states.append(S0_torch.squeeze(0))

            # sdeint output here is (time, 1, asset); drop only the batch axis.
            S_paths = torchsde.sdeint(sde_model, S0_torch, time_points, method='euler').squeeze(1)
            path_np = S_paths.detach().numpy()
            price_paths.append(path_np)

            # Drawn even when not plotting, so the NumPy random stream (and
            # anything relying on it later, e.g. the unseeded train_test_split
            # below) is the same with and without --plot_paths.
            asset_idx = np.random.randint(1, path_np.shape[1])
            if args['plot_paths']:
                plt.plot(time_points.numpy(), path_np[:, asset_idx])
                plt.show()

    # (path, time, asset)
    price_paths_torch = torch.tensor(np.stack(price_paths, axis=0), dtype=torch.float32)
    # (path, asset)
    init_cond = torch.stack(initial_states)

    X_train, X_valid, Y_train, Y_valid = train_test_split(
        init_cond, price_paths_torch,
        test_size=args['train_test_and_valid_split'], random_state=1)
    X_valid, X_test, Y_valid, Y_test = train_test_split(X_valid, Y_valid, test_size=0.5)

    train_data = TensorDataset(X_train, Y_train)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    valid_data = TensorDataset(X_valid, Y_valid)
    valid_loader = DataLoader(valid_data, batch_size=len(valid_data), shuffle=False)

    test_data = TensorDataset(X_test, Y_test)
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

    # Initialize model, optimizer, and loss function
    input_dim = num_assets
    model = NeuralSDE(input_dim)

    optimizer_classes = {
        1: torch.optim.Adam,
        2: torch.optim.Adadelta,
        3: torch.optim.SGD,
    }
    optimizer_cls = optimizer_classes.get(args['optimizer'], torch.optim.LBFGS)
    optimizer = optimizer_cls(model.parameters(), lr=args['lr'])

    loss_fn = nn.MSELoss()

    def _simulate(x0, ts):
        """Integrate the neural SDE from ``x0`` over ``ts``; result is (batch, time, asset)."""
        # No blanket squeeze here: it would also drop a batch axis of size 1
        # and make the permute below fail.
        return torchsde.sdeint(model, x0, ts).permute(1, 0, 2)

    # penalty term multiplier
    lambdA = args['lambdA_train']
    # early stopping parameters
    patience = 0
    # Infinite sentinel: the monitored loss is a *sum* of squared errors and
    # can exceed any fixed constant, which would prevent a checkpoint from
    # ever being written.
    min_loss = torch.tensor(float('inf'))
    max_patience = args['max_patience']
    checkpoint_path = f"best_NSDE_model/model_{args['id']}.pt"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    best_saved = False

    for epoch in range(args['max_epochs']):
        # Number of leading time points used this epoch; k only changes at the
        # end of an epoch, so this is constant within it.
        n_steps = int(k*len(time_points)-1)

        model.train()
        for x_batch, y_batch in train_loader:

            def closure():
                # Trajectory MSE plus a penalty on the final predicted state.
                # NOTE(review): S_pred[:, -1] is time index n_steps-1 while the
                # target is y[:, n_steps] (one step later) -- confirm this
                # one-step offset is intended.
                optimizer.zero_grad()
                S_pred = _simulate(x_batch, time_points[:n_steps])
                batch_loss = (loss_fn(S_pred, y_batch[:, :n_steps, :])
                              + lambdA*torch.sum((S_pred[:, -1, :] - y_batch[:, n_steps, :])**2))
                batch_loss.backward()
                return batch_loss

            # A closure is mandatory for LBFGS and accepted by the other
            # optimizers, which evaluate it exactly once per step.
            optimizer.step(closure)

        model.eval()
        # NOTE(review): the early-stopping metric is computed on test_loader
        # (printed as "validation"), while valid_loader is used for the final
        # evaluation below -- confirm the roles are as intended.
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                S_pred = _simulate(x_batch, time_points[:n_steps])
                loss = torch.sum((S_pred[:, -1, :] - y_batch[:, n_steps, :])**2)
                print(f'Epoch {epoch}, validation loss: {loss.item()}')

        # After delta_k_epoch epochs, switch to the full horizon.
        if (epoch+1) % delta_k_epoch == 0:
            k = 1
        # Checkpointing / early stopping is only active once k == 1.
        if k == 1:
            if loss < min_loss:
                torch.save(model.state_dict(), checkpoint_path)
                best_saved = True
                min_loss = loss
                patience = 0
            else:
                patience += 1
                if patience >= max_patience:
                    break

    if best_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        # k never reached 1, so nothing was written by this run; do not load a
        # checkpoint that may be left over from an earlier run with the same id.
        print('No checkpoint was saved during training; evaluating the final weights.')
    model.eval()

    with torch.no_grad():
        for x_batch, y_batch in valid_loader:
            S_pred = _simulate(x_batch, time_points)
            loss = torch.sum((S_pred[:, -1, :] - y_batch[:, -1, :])**2)
            print(f'Epoch {epoch}, test loss: {loss.item()}')

    record = {
        'id': [args['id']],
        'mu': [mu],
        'sigma': [sigma],
        'MSE_valid': [loss.detach().numpy()],
        'MSE_test': [min_loss.detach().numpy()]
        }

    torch.save(model.state_dict(), checkpoint_path)
    df = pd.DataFrame(record)
    df.to_csv('NSDE_dyn_portfolio_results.csv', mode='a', header=False, index=False)
